{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Convert a PyTorch Model with Custom Ops and Run Inference\n",
    "\n",
    "With the onnxruntime_extensions package, a PyTorch model that uses an operation with no equivalent among the standard ONNX operators can still be converted to ONNX, and the converted model can still be run with ONNX Runtime plus onnxruntime_extensions. This tutorial shows how it works."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Converting\n",
    "Suppose a model cannot be converted because the standard ONNX opset has no matrix inverse operator. The model is defined as follows."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "class CustomInverse(torch.nn.Module):\n",
    "    # torch.inverse has no counterpart in the standard ONNX opset\n",
    "    def forward(self, x):\n",
    "        return torch.inverse(x) + x"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "To export this model to ONNX format, we need to register a custom op symbolic handler with the PyTorch ONNX exporter (torch.onnx)."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "from torch.onnx import register_custom_op_symbolic\n",
    "\n",
    "\n",
    "def my_inverse(g, self):\n",
    "    # Emit the inverse call as an Inverse node in the ai.onnx.contrib custom domain\n",
    "    return g.op(\"ai.onnx.contrib::Inverse\", self)\n",
    "\n",
    "register_custom_op_symbolic('::inverse', my_inverse, 1)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Then invoke the exporter."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "import io\n",
    "import onnx\n",
    "\n",
    "x0 = torch.randn(3, 3)\n",
    "# Export model to ONNX\n",
    "f = io.BytesIO()\n",
    "t_model = CustomInverse()\n",
    "torch.onnx.export(t_model, (x0, ), f, opset_version=12)\n",
    "onnx_model = onnx.load(io.BytesIO(f.getvalue()))"
   ],
   "outputs": [],
   "metadata": {
    "scrolled": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "We now have an ONNX model in memory. It can be saved to a file on disk with `onnx.save_model(onnx_model, <file_path>)`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Inference\n",
    "The converted model cannot run directly in onnxruntime because of the custom operator, but it can run easily with onnxruntime_extensions.\n",
    "\n",
    "First, let's define a PyOp function to interpret the custom op node in the ONNX model."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "import numpy\n",
    "from onnxruntime_extensions import onnx_op, PyOp\n",
    "\n",
    "\n",
    "@onnx_op(op_type=\"Inverse\")\n",
    "def inverse(x):\n",
    "    # The user's custom op implementation goes here:\n",
    "    return numpy.linalg.inv(x)\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "* **ONNX Inference**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "from onnxruntime_extensions import PyOrtFunction\n",
    "onnx_fn = PyOrtFunction.from_model(onnx_model)\n",
    "y = onnx_fn(x0.numpy())\n",
    "print(y)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[-3.081008    0.20269153  0.42009977]\n",
      " [-3.3962293   2.5986686   2.4447646 ]\n",
      " [ 0.7805753  -0.20394287 -2.7528977 ]]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "* **Compare the result with PyTorch**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "t_y = t_model(x0)\n",
    "# The PyTorch and ONNX results should agree to 5 decimal places\n",
    "numpy.testing.assert_almost_equal(t_y.numpy(), y, decimal=5)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Implement the custom op in C++ (optional)\n",
    "To run the ONNX model with the custom op from any other language supported by ONNX Runtime, independent of Python, a C++ implementation is needed. See [inverse.hpp](../operators/math/inverse.hpp) for an example of how to do that."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "from onnxruntime_extensions import enable_py_op\n",
    "# disable the PyOp function and run with the C++ function\n",
    "enable_py_op(False)\n",
    "y = onnx_fn(x0.numpy())\n",
    "print(y)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[-3.081008    0.20269153  0.42009977]\n",
      " [-3.3962293   2.5986686   2.4447646 ]\n",
      " [ 0.7805753  -0.20394287 -2.7528977 ]]\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}